import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Dict, Any, Optional, Tuple

from hypersense.importance.base_analyzer import BaseImportanceAnalyzer


class NRReliefFAnalyzer(BaseImportanceAnalyzer):
    """
    More robust N-RReliefF (RReliefF for continuous targets):
    - Adaptive sigma (bandwidth estimated from the typical m-th nearest-neighbor distance)
    - Neighbor weights from a Gaussian kernel on the actual distance (not neighbor rank)
    - Multiple runs and averaging to reduce randomness
    - Stable non-negative shift + normalization (with smoothing & fallback)
    - Boundary handling (constant y, k>n, updates>n, no features, duplicate rows, etc.)

    Numerical columns and one-hot encoded categorical columns are scored in
    separate passes: neighbors for the numerical features are searched in the
    numerical subspace only, and neighbors for the categorical features in the
    one-hot subspace only.
    """

    def __init__(
        self,
        k: int = 10,
        sigma: Optional[float] = None,      # None => adaptive bandwidth
        updates: str | int = "all",         # "all" => visit every row once; int => random draws
        n_repeats: int = 1,                 # Average over multiple runs
        repeat_seed: Optional[int] = None,  # Random seed for repeats
        adaptive_sigma_sample: int = 200,   # Subset size for adaptive bandwidth estimation
        verbose: bool = True,               # tqdm progress bar
        smoothing_eps: float = 1e-6,        # Normalization smoothing
        normalize: bool = True,             # Normalize weights to sum to 1
    ):
        super().__init__()
        self.k = k
        self.sigma = sigma
        self.updates = updates
        self.n_repeats = max(1, int(n_repeats))
        self.repeat_seed = repeat_seed
        self.adaptive_sigma_sample = adaptive_sigma_sample
        self.verbose = verbose
        self.smoothing_eps = smoothing_eps
        self.normalize = normalize

        self.params_name: Optional[List[str]] = None

    # -------------------- public API --------------------

    def fit(self, configs: List[Dict[str, Any]], scores: List[float]) -> None:
        """
        Fit N-RReliefF using configurations and scores.

        Args:
            configs: One dict per evaluated configuration (parameter name -> value).
                Numeric columns are used as-is; all other columns go through
                ``pd.get_dummies``.
            scores: Objective value of each configuration, aligned with ``configs``.

        Raises:
            ValueError: If ``configs`` and ``scores`` have different lengths.

        After fitting, ``feature_importances_`` maps every feature name (numerical
        column names followed by one-hot categorical column names) to its weight.
        With no features at all it is an empty dict.
        """
        X_df = pd.DataFrame(configs)
        y = np.asarray(scores, dtype=float).ravel()
        if len(y) != len(X_df):
            raise ValueError(
                f"configs and scores must have the same length "
                f"(got {len(X_df)} configs and {len(y)} scores)"
            )

        numerical_cols = X_df.select_dtypes(include=[np.number]).columns.tolist()
        categorical_cols = X_df.select_dtypes(exclude=[np.number]).columns.tolist()
        n_rows = len(X_df)

        if numerical_cols:
            numericalX = X_df[numerical_cols].to_numpy(dtype=float)
        else:
            numericalX = np.empty((n_rows, 0))

        # Encode categoricals once so the matrix and its column names come from
        # the very same call.
        if categorical_cols:
            dummies = pd.get_dummies(X_df[categorical_cols])
            categoricalX = dummies.to_numpy(dtype=float)
            onehot_cat_names = dummies.columns.tolist()
        else:
            categoricalX = np.empty((n_rows, 0))
            onehot_cat_names = []

        # Record feature order (numerical + one-hot encoded categorical column names)
        self.params_name = numerical_cols + onehot_cat_names

        # Adaptive sigma (if not specified).
        # NOTE(review): sigma is estimated on the concatenated numerical + one-hot
        # space, while _rrelieff searches neighbors in each subspace separately,
        # so the bandwidth may be wider than a per-subspace estimate. Confirm intent.
        if self.sigma is None:
            self.sigma_ = self._estimate_sigma(np.hstack([numericalX, categoricalX]))
        else:
            self.sigma_ = float(self.sigma)

        # No features to score: indexing/min on an empty weight vector would fail.
        if not self.params_name:
            self.feature_importances_ = {}
            return

        # Multiple runs and averaging; each run gets its own seed derived from
        # repeat_seed (None keeps the runs unseeded).
        rng = np.random.default_rng(self.repeat_seed)
        all_runs = []
        for _ in range(self.n_repeats):
            seed_now = None if self.repeat_seed is None else int(rng.integers(0, 2**31 - 1))
            all_runs.append(self._analyse_once(numericalX, categoricalX, y, random_seed=seed_now))

        W = np.asarray(np.mean(np.vstack(all_runs), axis=0), dtype=float)

        # Stable non-negative + smoothing + normalization + fallback
        if self.normalize:
            if np.allclose(W, W[0]):
                # All weights equal: nothing distinguishes features, spread uniformly.
                W = np.ones_like(W) / len(W)
            else:
                # Shift so the least important feature sits at ~0; eps avoids exact 0.
                W = W - W.min() + self.smoothing_eps
                total = W.sum()
                if total <= 0 or not np.isfinite(total):
                    W = np.ones_like(W) / len(W)
                else:
                    W = W / total
        else:
            # When not normalizing, only apply the non-negative shift.
            W = W - W.min() + self.smoothing_eps

        self.feature_importances_ = {name: float(w) for name, w in zip(self.params_name, W)}

    def explain(self) -> Dict[str, float]:
        """Return the fitted feature -> importance mapping.

        Raises:
            ValueError: If ``fit`` has not produced importances yet.
        """
        importances = getattr(self, "feature_importances_", None)
        if importances is None:
            raise ValueError("Importance scores not available. Call fit() first.")
        return importances

    def explain_interactions(self) -> Dict[Tuple[str, str], float]:
        # Relief-based methods do not directly model interactions
        return {}

    def rank(self) -> List[str]:
        """Return feature names sorted by decreasing importance.

        Raises:
            ValueError: If ``fit`` has not produced importances yet.
        """
        importances = getattr(self, "feature_importances_", None)
        if importances is None:
            raise ValueError("Importance scores not available. Call fit() first.")
        return sorted(importances, key=importances.__getitem__, reverse=True)

    # -------------------- single run core --------------------

    def _analyse_once(self, numericalX, categoricalX, y, random_seed: Optional[int]) -> np.ndarray:
        """
        Run one ReliefF pass and return (D,) weights: numerical columns first,
        then one-hot categorical columns. Both passes share ``random_seed``.
        """
        W_num = self._rrelieff(numericalX, y, categoricalx=False, random_seed=random_seed)
        W_cat = self._rrelieff(categoricalX, y, categoricalx=True, random_seed=random_seed)
        return np.concatenate((W_num.ravel(), W_cat.ravel())).astype(float, copy=False)

    # -------------------- adaptive sigma --------------------

    def _estimate_sigma(self, X: np.ndarray) -> float:
        """
        Estimate sigma as the median m-th nearest-neighbor distance, with
        m = max(5, int(sqrt(n))), over a random subset of query rows.

        Returns 1.0 when there are too few rows/columns to estimate, and never
        less than 1e-6.
        """
        n, d = X.shape
        if n <= 2 or d == 0:
            return 1.0

        m = max(5, int(np.sqrt(n)))  # Use the m-th nearest neighbor
        if n <= m:
            # No query can have an m-th neighbor (sorted index 0 is its own distance).
            return 1.0

        rng = np.random.default_rng(self.repeat_seed)
        sample_n = min(self.adaptive_sigma_sample, n)
        idx = rng.choice(n, size=sample_n, replace=False)

        mth_dists = []
        for i in idx:
            dist = np.sqrt(((X - X[i]) ** 2).sum(axis=1))
            # np.partition places the value a full sort would put at index m.
            mth_dists.append(np.partition(dist, m)[m])

        mdist_m = float(np.median(mth_dists)) if mth_dists else 1.0
        # Bandwidth should not be too small, set a floor
        return float(max(mdist_m, 1e-6))

    # -------------------- core ReliefF --------------------

    def _rrelieff(self, X, y, categoricalx=False, random_seed: Optional[int] = None):
        """
        Run one RReliefF pass over the columns of ``X``.

        Args:
            X: (n, D) feature matrix (numerical values or one-hot indicators).
            y: (n,) targets aligned with the rows of ``X``.
            categoricalx: If True, per-feature difference is 0/1 inequality;
                otherwise |x1 - x2| divided by the column's range (0 for
                constant columns).
            random_seed: Seed for drawing query rows when ``updates`` is an int.

        Returns:
            Weight array of shape (D, 1). Uniform 1/D when ``y`` is constant.
        """
        n, d = X.shape
        if d == 0:
            return np.zeros((0, 1))

        # Number of updates
        if self.updates == "all":
            m = n
        else:
            m = max(1, min(int(self.updates), n))

        # yRange to prevent division by zero
        yMin, yMax = np.min(y), np.max(y)
        yRange = yMax - yMin
        if yRange <= 0:
            # All scores are equal, return uniform
            return np.ones((d, 1), dtype=float) / d

        # k cannot exceed n-1 (the query itself is never its own neighbor).
        k = max(1, min(self.k, n - 1))

        X = np.asarray(X, dtype=float)
        if not categoricalx:
            # Column ranges depend only on X: compute them once instead of per
            # (query, neighbor, attribute) triple.
            col_range = X.max(axis=0) - X.min(axis=0)
            constant = col_range <= 0
            safe_range = np.where(constant, 1.0, col_range)

        rng = np.random.default_rng(random_seed)

        # Accumulators
        N_dC = 0.0
        N_dA = np.zeros(d, dtype=float)
        N_dCanddA = np.zeros(d, dtype=float)

        it = range(m)
        if self.verbose:
            it = tqdm(it, desc="Fitting N-RReliefF", leave=False)

        for i in it:
            # Traverse in order ("all") or draw a random query row
            idx = i if self.updates == "all" else int(rng.integers(0, n))
            x0 = X[idx]

            # Exclude the query by index so duplicate rows cannot make it its own neighbor.
            XKNN, neigh_idx, neigh_dist = self._knnsearch(X, x0, k, exclude=idx)

            # Distance-based weights (Gaussian kernel), summing to 1 over the neighbors
            wj = self._distance_weights(neigh_dist, self.sigma_)
            dy = np.abs(y[idx] - y[neigh_idx]) / yRange  # (k,)

            # Per-neighbor, per-attribute differences, shape (k, d)
            if categoricalx:
                diffs = (XKNN != x0).astype(float)
            else:
                diffs = np.where(constant, 0.0, np.abs(x0 - XKNN) / safe_range)

            N_dC += float(np.sum(dy * wj))
            N_dA += wj @ diffs
            N_dCanddA += (dy * wj) @ diffs

        # W_A = P(diff A | diff y) - P(diff A | same y). A conditional term whose
        # denominator is (numerically) zero is undefined and contributes 0
        # instead of producing inf/NaN.
        tol = 1e-12 * m
        denom_same = m - N_dC
        term_diff = N_dCanddA / N_dC if N_dC > tol else np.zeros(d)
        term_same = (N_dA - N_dCanddA) / denom_same if denom_same > tol else np.zeros(d)
        return (term_diff - term_same).reshape(d, 1)

    # -------------------- neighbors & weights --------------------

    def _knnsearch(self, A: np.ndarray, b: np.ndarray, k: int, exclude: Optional[int] = None):
        """
        Brute-force Euclidean k-nearest-neighbor search.

        Args:
            A: (n, D) candidate points.
            b: (D,) query point.
            k: Number of neighbors to return (capped at n-1).
            exclude: Row index of the query inside ``A``. When given, that row is
                never returned, even if other rows share its coordinates. When
                omitted, the closest row is assumed to be the query and skipped.

        Returns:
            (neighbor rows, neighbor indices, neighbor distances), nearest first.
        """
        A = A.astype(float, copy=False)
        b = b.astype(float, copy=False)
        diff = A - b
        dist = np.sqrt(np.sum(diff * diff, axis=1))
        if exclude is None:
            # Legacy behavior: skip the first (usually zero-distance) entry.
            neigh_idx = np.argsort(dist)[1:k + 1]
        else:
            masked = dist.copy()
            masked[exclude] = np.inf
            neigh_idx = np.argsort(masked, kind="stable")[:min(k, len(A) - 1)]
        return A[neigh_idx], neigh_idx, dist[neigh_idx]

    def _distance_weights(self, dists: np.ndarray, sigma: float) -> np.ndarray:
        """
        Gaussian kernel weights exp(-(d / sigma)^2), normalized to sum to 1.

        Falls back to sigma = 1.0 for a non-positive/non-finite sigma, and to
        uniform weights if the kernel sum is unusable.
        """
        dists = np.asarray(dists, dtype=float)
        if dists.size == 0:
            return dists
        if sigma <= 0 or not np.isfinite(sigma):
            sigma = 1.0
        z = (dists / sigma) ** 2
        # Subtracting the smallest exponent cancels in the normalization but keeps
        # the nearest neighbor's weight at 1, so the kernel cannot underflow to all
        # zeros when every neighbor is far relative to sigma.
        w = np.exp(-(z - z.min()))
        s = w.sum()
        if s <= 0 or not np.isfinite(s):
            return np.ones_like(dists) / len(dists)
        return w / s

    # -------------------- diff functions --------------------
    # Scalar reference forms of the differences that _rrelieff computes in
    # vectorized form.

    def _diff_numeric(self, A, x1, x2, X):
        """
        Numeric difference: |x1[A] - x2[A]| normalized by the range of column A
        of X (0 for a constant column).
        """
        col = X[:, A]
        cmin, cmax = np.min(col), np.max(col)
        denom = cmax - cmin
        if denom <= 0:
            return 0.0
        return float(np.abs(x1[A] - x2[A]) / denom)

    def _diff_categorical(self, A, x1, x2, X):
        """
        Categorical difference: 0 if x1[A] == x2[A], else 1.
        """
        return 0.0 if x1[A] == x2[A] else 1.0

